{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import cirq\n",
    "except ImportError:\n",
    "    print(\"installing cirq...\")\n",
    "    !pip install --quiet cirq\n",
    "    print(\"installed cirq.\")\n",
    "\n",
    "try:\n",
    "    import quimb\n",
    "except ImportError:\n",
    "    print(\"installing cirq-core[contrib]...\")\n",
    "    !pip install --quiet 'cirq-core[contrib]'\n",
    "    print(\"installed cirq-core[contrib].\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Contract a Grid Circuit\n",
    "Shallow circuits on a planar grid with low-weight observables permit easy contraction."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import networkx as nx\n",
    "\n",
    "import cirq\n",
    "import quimb\n",
    "import quimb.tensor as qtn\n",
    "from cirq.contrib.svg import SVGCircuit\n",
    "\n",
    "import cirq.contrib.quimb as ccq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import seaborn as sns\n",
    "\n",
    "sns.set_style('ticks')\n",
    "\n",
    "plt.rc('axes', labelsize=16, titlesize=16)\n",
    "plt.rc('xtick', labelsize=14)\n",
    "plt.rc('ytick', labelsize=14)\n",
    "plt.rc('legend', fontsize=14, title_fontsize=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# theme colors\n",
    "QBLUE = '#1967d2'\n",
    "QRED = '#ea4335ff'\n",
    "QGOLD = '#fbbc05ff'\n",
    "QGREEN = '#34a853ff'\n",
    "\n",
    "QGOLD2 = '#ffca28'\n",
    "QBLUE2 = '#1e88e5'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make an example circuit topology\n",
    "We'll use entangling gates according to this topology and compute the value of an observable on the red nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "width = 3\n",
    "height = 4\n",
    "graph = nx.grid_2d_graph(width, height)\n",
    "rs = np.random.RandomState(52)\n",
    "nx.set_edge_attributes(\n",
    "    graph, name='weight', values={e: np.round(rs.uniform(), 2) for e in graph.edges}\n",
    ")\n",
    "\n",
    "zz_inds = ((width // 2, (height // 2 - 1)), (width // 2, (height // 2)))\n",
    "nx.draw_networkx(\n",
    "    graph,\n",
    "    pos={n: n for n in graph.nodes},\n",
    "    node_color=[QRED if node in zz_inds else QBLUE for node in graph.nodes],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Circuit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qubits = [cirq.GridQubit(*n) for n in graph]\n",
    "circuit = cirq.Circuit(\n",
    "    cirq.H.on_each(qubits),\n",
    "    ccq.get_grid_moments(graph),\n",
    "    cirq.Moment([cirq.rx(0.456).on_each(qubits)]),\n",
    ")\n",
    "SVGCircuit(circuit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Observable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ZZ = cirq.Z(cirq.GridQubit(*zz_inds[0])) * cirq.Z(cirq.GridQubit(*zz_inds[1]))\n",
    "ZZ"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The contraction\n",
    "The value of the observable is $\\langle 0 | U^\\dagger (ZZ) U |0 \\rangle$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tot_c = ccq.circuit_for_expectation_value(circuit, ZZ)\n",
    "SVGCircuit(tot_c)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## We can simplify the circuit\n",
    "By cancelling the \"forwards\" and \"backwards\" parts of the circuit that lie outside the light cone of the observable, we can reduce the number of gates to consider, and sometimes the number of qubits involved. To see this in action, run the next code cell once, then re-run the code cell after it to watch gates disappear from the circuit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compressed_c = tot_c.copy()\n",
    "print(len(list(compressed_c.all_operations())), len(compressed_c.all_qubits()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**(try re-running the following cell to watch the circuit get smaller)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "compressed_c = cirq.merge_k_qubit_unitaries(compressed_c, k=2)\n",
    "compressed_c = cirq.merge_k_qubit_unitaries(compressed_c, k=1)\n",
    "\n",
    "compressed_c = cirq.drop_negligible_operations(compressed_c, atol=1e-6)\n",
    "compressed_c = cirq.drop_empty_moments(compressed_c)\n",
    "print(len(list(compressed_c.all_operations())), len(compressed_c.all_qubits()))\n",
    "SVGCircuit(compressed_c)"
   ]
  },
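  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A short sketch of why the gates cancel, using notation introduced here rather than taken from the library: suppose the circuit factors as $U = U_\\text{out} U_\\text{in}$, where $U_\\text{out}$ is applied last and contains only gates acting on qubits outside the support of $ZZ$. Then $U_\\text{out}$ commutes with $ZZ$, so\n",
    "\n",
    "$$U^\\dagger (ZZ)\\, U = U_\\text{in}^\\dagger U_\\text{out}^\\dagger (ZZ)\\, U_\\text{out} U_\\text{in} = U_\\text{in}^\\dagger (ZZ)\\, U_\\text{out}^\\dagger U_\\text{out} U_\\text{in} = U_\\text{in}^\\dagger (ZZ)\\, U_\\text{in}.$$\n",
    "\n",
    "Repeating this argument layer by layer removes every gate that never touches the growing support of the conjugated operator, which is why only the gates inside the light cone survive."
   ]
  },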
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Utility function to fully simplify\n",
    "\n",
    "We provide a utility function, `ccq.simplify_expectation_value_circuit`, that fully simplifies a circuit. The call below updates `tot_c` in place."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ccq.simplify_expectation_value_circuit(tot_c)\n",
    "SVGCircuit(tot_c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simplification might eliminate qubits entirely for large graphs and\n",
    "# shallow circuits, so re-fetch the current qubits.\n",
    "qubits = sorted(tot_c.all_qubits())\n",
    "print(len(qubits))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Turn it into a Tensor Network\n",
    "\n",
    "We explicitly \"cap\" the tensor network with `<0..0|` bras so the entire thing contracts to the expectation value $\\langle 0 | U^\\dagger (ZZ) U |0 \\rangle$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensors, qubit_frontier, fix = ccq.circuit_to_tensors(circuit=tot_c, qubits=qubits)\n",
    "end_bras = [\n",
    "    qtn.Tensor(data=quimb.up().squeeze(), inds=(f'i{qubit_frontier[q]}_q{q}',), tags={'Q0', 'bra0'})\n",
    "    for q in qubits\n",
    "]\n",
    "\n",
    "tn = qtn.TensorNetwork(tensors + end_bras)\n",
    "tn.graph(color=['Q0', 'Q1', 'Q2'])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `rank_simplify` effectively folds together 1- and 2-qubit gates\n",
    "\n",
    "In practice, using this is faster than running the circuit optimizer to remove gates that cancel themselves, but please benchmark for your particular use case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tn.rank_simplify(inplace=True)\n",
    "tn.graph(color=['Q0', 'Q1', 'Q2'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The tensor contraction path tells us how expensive this will be"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_info = tn.contract(get='path-info')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_info.opt_cost / int(3e9)  # estimated runtime in seconds, assuming 3 GFLOP/s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "path_info.largest_intermediate * 128 / 8 / 1024 / 1024 / 1024  # size in GiB, at 128 bits per element"
   ]
  },
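  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two estimates above are unit conversions. As a rough sketch, assume that `opt_cost` counts floating-point operations, that the machine sustains $3 \\times 10^{9}$ operations per second, and that each tensor element takes 128 bits. Writing $C$ for `path_info.opt_cost` and $N$ for `path_info.largest_intermediate` (shorthand introduced here), the runtime $t$ and peak intermediate size $M$ are\n",
    "\n",
    "$$t \\approx \\frac{C}{3 \\times 10^{9}}\\ \\text{s}, \\qquad M = \\frac{N \\times 128}{8 \\times 1024^{3}}\\ \\text{GiB} = \\frac{16\\,N}{2^{30}}\\ \\text{GiB}.$$\n",
    "\n",
    "Both numbers are order-of-magnitude guides; actual performance depends on the hardware and the contraction backend."
   ]
  },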
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Do the contraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "zz = tn.contract(inplace=True)\n",
    "zz = np.real_if_close(zz)\n",
    "print(zz)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Big Circuit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "width = 8\n",
    "height = 8\n",
    "graph = nx.grid_2d_graph(width, height)\n",
    "rs = np.random.RandomState(52)\n",
    "nx.set_edge_attributes(\n",
    "    graph, name='weight', values={e: np.round(rs.uniform(), 2) for e in graph.edges}\n",
    ")\n",
    "\n",
    "zz_inds = ((width // 2, (height // 2 - 1)), (width // 2, (height // 2)))\n",
    "nx.draw_networkx(\n",
    "    graph,\n",
    "    pos={n: n for n in graph.nodes},\n",
    "    node_color=[QRED if node in zz_inds else QBLUE for node in graph.nodes],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qubits = [cirq.GridQubit(*n) for n in graph]\n",
    "circuit = cirq.Circuit(\n",
    "    cirq.H.on_each(qubits),\n",
    "    ccq.get_grid_moments(graph),\n",
    "    cirq.Moment([cirq.rx(0.456).on_each(qubits)]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ZZ = cirq.Z(cirq.GridQubit(*zz_inds[0])) * cirq.Z(cirq.GridQubit(*zz_inds[1]))\n",
    "ZZ"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ccq.tensor_expectation_value(circuit, ZZ)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
